function EFX = moment_EFX_conditional_Model20220926(parms,gesol,F,X_Nxz)
% MOMENT_EFX_CONDITIONAL_MODEL20220926  Three-period conditional moment on the (N,x,z) grid.
%
%   EFX(N,x,z) = E[ F( X(N,x,z) + X(N',x',z') + X(N'',x'',z'') ) | N,x,z ]
%
% The expectation is taken over the two-step path (z',z'') of the discrete
% state z, weighted by StateTransitionProbs(z,z')*StateTransitionProbs(z',z'').
%
% Inputs:
%   parms  struct; fields read here: NN, Nx, Nz, grid_N, grid_x, grid_z,
%          sigma_x (indexed as (iN,iz)), StateTransitionProbs (Nz x Nz,
%          row = current state), OPTION_x_standardize_shock, s (indexed by
%          iz), rho_x, x_bar.
%   gesol  model solution; only gesol.f is read, indexed as (iN,ix,iz).
%   F      function handle applied to an NN x Nx matrix, e.g. F=@(X)log(X).
%   X_Nxz  values of X on the grid, indexed as (iN,ix,iz).
%
% Output:
%   EFX    NN x Nx x Nz array holding the conditional moment at every gridpoint.
%
% Laws of motion as implemented (subscript 0 = today, 1 = next period):
%   N1 = (1-s(z0))*N0 + f(N0,x0,z0)*(1-N0)
%   x1 = (1-rho_x)*x_bar + rho_x*x0 + sigma_x(N0,z0)*(z1 - E[z1|z0])/std0
%   N2 = (1-s(z1))*N1 + f(N1,x1,z1)*(1-N1)
%   x2 = (1-rho_x)*x_bar + rho_x*x1 + sigma_x(N1,z1)*(z2 - E[z2|z1])/std1
% where std = std(z'|z) if parms.OPTION_x_standardize_shock is true and 1
% otherwise. N and x are clamped to their grid ranges before interpolating.
% Off-grid values of f, X and sigma_x are obtained by linear interpolation
% (interp2 / interp1 defaults).

    P     = parms.StateTransitionProbs;
    zgrid = parms.grid_z(:);

    EFX = zeros(parms.NN, parms.Nx, parms.Nz);

    Nmin = parms.grid_N(1);   Nmax = parms.grid_N(end);
    xmin = parms.grid_x(1);   xmax = parms.grid_x(end);

    % Today's state replicated over the (N,x) grid: rows index N, columns index x.
    N0_Nx = repmat(parms.grid_N(:),   [1 parms.Nx]);
    x0_Nx = repmat(parms.grid_x(:).', [parms.NN 1]);

    % Conditional mean and (optionally) standard deviation of next period's z.
    expected_znext = P*zgrid;
    if parms.OPTION_x_standardize_shock
        var_znext = P*(zgrid.^2) - expected_znext.^2;
        % Rounding can push a (near-)degenerate variance slightly below zero,
        % which would make sqrt return a complex number.
        var_znext(var_znext < 0) = 0;
        std_znext = sqrt(var_znext);
        % Degenerate conditional distribution: every successor with positive
        % probability has z' = E[z'|z], so the shock is zero whatever the
        % divisor. Use 1 to avoid 0/0 = NaN on the probability-one path.
        std_znext(std_znext == 0) = 1;
    else
        std_znext = ones(size(expected_znext));
    end

    % interp2 wants the sample grid in meshgrid layout (rows = x, columns = N),
    % hence the transposes of the (N,x)-ordered arrays below.
    [N_mesh,x_mesh] = meshgrid(parms.grid_N(:),parms.grid_x(:));

    % Deterministic part of x1; does not depend on any of the z indices.
    x1_drift = (1-parms.rho_x)*parms.x_bar + parms.rho_x*x0_Nx;

    for iz0 = 1:parms.Nz
        % N1 = N1(N0,x0) for given z0
        N1 = (1-parms.s(iz0))*N0_Nx + gesol.f(:,:,iz0).*(1-N0_Nx);
        % Logical-index clamping (rather than min/max) leaves NaN as NaN.
        N1(N1<Nmin) = Nmin;
        N1(N1>Nmax) = Nmax;

        X0 = X_Nxz(:,:,iz0);
        sigma_x0 = repmat(parms.sigma_x(:,iz0),[1 parms.Nx]);

        for iz1 = 1:parms.Nz
            prob_01 = P(iz0,iz1);
            if prob_01 == 0
                % Unreachable successor: contributes nothing, and evaluating
                % it could inject NaN via 0*Inf if F is not finite there.
                continue
            end

            x1 = x1_drift ...
                + sigma_x0*(parms.grid_z(iz1) - expected_znext(iz0))/std_znext(iz0);
            x1(x1<xmin) = xmin;
            x1(x1>xmax) = xmax;

            % N2 = N2(N1,x1) for given z1
            N2 = (1-parms.s(iz1))*N1 ...
                + interp2(N_mesh,x_mesh,gesol.f(:,:,iz1)',N1,x1).*(1-N1);
            N2(N2<Nmin) = Nmin;
            N2(N2>Nmax) = Nmax;

            X1 = interp2(N_mesh,x_mesh,X_Nxz(:,:,iz1)',N1,x1);

            % Shock scale for the second step, evaluated at (N1,z1).
            sigma_x1 = interp1(parms.grid_N(:),parms.sigma_x(:,iz1),N1);

            for iz2 = 1:parms.Nz
                prob_12 = P(iz1,iz2);
                if prob_12 == 0
                    continue
                end

                x2 = (1-parms.rho_x)*parms.x_bar + parms.rho_x*x1 ...
                    + sigma_x1*(parms.grid_z(iz2) - expected_znext(iz1))/std_znext(iz1);
                x2(x2<xmin) = xmin;
                x2(x2>xmax) = xmax;

                X2 = interp2(N_mesh,x_mesh,X_Nxz(:,:,iz2)',N2,x2);

                path_prob = prob_01*prob_12;
                EFX(:,:,iz0) = EFX(:,:,iz0) + path_prob*F(X0 + X1 + X2);
            end
        end
    end

end